Exploratory Image Analysis¶

To familiarize ourselves with the dataset, we explore its basic statistics and class characteristics, evaluating it both quantitatively and qualitatively.

We examine the visual features YOLOv8 relies on at the low level (edges, textures, color), the mid level (object parts such as beards and hats), and the high level (humanoid structure and the red–white gestalt). This helps us identify potential biases and sources of confusion, particularly with Santa-like hard negatives (e.g. the Grinch).
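These cues can be quantified crudely per image. A numpy-only sketch (the "red"/"white" color thresholds and the gradient threshold are ad-hoc assumptions of ours, not dataset properties):

```python
import numpy as np

def low_level_summary(img_bgr, red_thresh=150, other_thresh=100, white_thresh=200):
    """Rough per-image summary of low-level cues: edge density
    (gradient magnitude) and the fraction of roughly red / roughly
    white pixels. All thresholds are ad-hoc assumptions."""
    b = img_bgr[..., 0].astype(int)
    g = img_bgr[..., 1].astype(int)
    r = img_bgr[..., 2].astype(int)
    red = (r > red_thresh) & (g < other_thresh) & (b < other_thresh)
    white = (r > white_thresh) & (g > white_thresh) & (b > white_thresh)

    gray = img_bgr.mean(axis=2)
    gy, gx = np.gradient(gray)
    edge_density = (np.hypot(gx, gy) > 30).mean()
    return edge_density, red.mean(), white.mean()

# Synthetic sanity check: a red square on a white canvas.
canvas = np.full((100, 100, 3), 255, dtype=np.uint8)
canvas[30:70, 30:70] = (0, 0, 255)  # BGR red
edge_density, red_frac, white_frac = low_level_summary(canvas)
```

Averaged per class (Santa crops vs. hard negatives), such summaries can hint at whether the red–white gestalt alone separates the classes.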

Guidelines for the project

  • data inspection: how many samples, classes, labels, class imbalances

  • object classes, distributions, statistics, imbalances, bias: consider these when training, e.g. with a weighted loss

  • get a qualitative and quantitative sense of the data

  • discuss dataset challenges and how to overcome them

  • visualize image labels and bounding boxes
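On the weighted-loss point: a standard remedy for class imbalance is inverse-frequency class weighting. A minimal sketch with hypothetical counts (this dataset in fact contains only the single class 'Santa', so the second class here is illustrative):

```python
from collections import Counter

def inverse_frequency_weights(class_counts):
    """Inverse-frequency class weights, normalized to mean 1.0,
    usable e.g. as per-class loss weights during training."""
    total = sum(class_counts.values())
    raw = {c: total / n for c, n in class_counts.items()}
    mean = sum(raw.values()) / len(raw)
    return {c: w / mean for c, w in raw.items()}

# Hypothetical counts: many Santa boxes, few hard-negative boxes.
weights = inverse_frequency_weights(Counter({"Santa": 366, "Grinch": 40}))
```

The rarer class receives a proportionally larger weight, which counteracts the gradient dominance of the majority class.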

Connect to Google Drive and load the Roboflow dataset¶

In [1]:
from google.colab import drive
drive.mount('/content/drive/', force_remount=True)

%cd /content/drive/My Drive/FHNW/HS_25/DLBS/minichallenge_hs25_object_detection/MC
ERROR:root:Internal Python error in the inspect module.
Mounted at /content/drive/
OSError: [Errno 107] Transport endpoint is not connected
In [2]:
from helpers import (
    load_roboflow_data,
    analyze_dataset,
    visualize_images,
)

from google.colab import userdata
api_key = userdata.get('ROBOFLOW_API_KEY')

!pip install roboflow
from roboflow import Roboflow
rf = Roboflow(api_key=api_key)
project = rf.workspace('dlbs-xi5zk').project('santa-qqpxm')
version = project.version(10)
dataset = version.download('yolov8')
loading Roboflow workspace...
loading Roboflow project...

Check Dataset Content and Structure¶

We inspect images qualitatively, count the classes, and check annotation consistency.

In [13]:
import os

# Content of the dataset directory
dataset_path = dataset.location
split_counts = {}

# List all items
for item in sorted(os.listdir(dataset_path)):
    item_path = os.path.join(dataset_path, item)

    if os.path.isdir(item_path):
        if item in ['train', 'valid', 'test']:
            image_dir = os.path.join(item_path, 'images')
            num_files = len(os.listdir(image_dir))
            split_counts[item] = num_files

# Calculate and display statistics
total = sum(split_counts.values())
print(f"\nπŸ”Ž \nTotal Images (Train + Valid + Test): {total}")
print("Count per split:")
for split in ['train', 'valid', 'test']:
    count = split_counts.get(split, 0)
    pct = (count / total * 100) if total > 0 else 0
    print(f"- {split.capitalize()}: {count} images ({pct:.1f}%)")

# Load and display the data.yaml file
import yaml
yaml_path = os.path.join(dataset_path, "data.yaml")

if os.path.exists(yaml_path):
    with open(yaml_path, 'r') as f:
        data_config = yaml.safe_load(f)

    print(f"\nπŸ“‹ \nDataset configuration:")
    print(f"- Classes: {data_config.get('names', [])}")
    print(f"- Number of classes: {data_config.get('nc', 'N/A')}")
    print(f"- Train images: {data_config.get('train', 'N/A')}")
    print(f"- Val images: {data_config.get('val', 'N/A')}")
    print(f"- Test images: {data_config.get('test', 'N/A')}")
else:
    print(f"\n⚠️ data.yaml not found at {yaml_path}")
πŸ”Ž 
Total Images (Train + Valid + Test): 950
Count per split:
- Train: 700 images (73.7%)
- Valid: 155 images (16.3%)
- Test: 95 images (10.0%)

πŸ“‹ 
Dataset configuration:
- Classes: ['Santa']
- Number of classes: 1
- Train images: ../train/images
- Val images: ../valid/images
- Test images: ../test/images
In [3]:
# Usage:
analyze_dataset(dataset.location)
🔎 Total: 950 | Train: 700 (73.7%) | Valid: 155 (16.3%) | Test: 95 (10.0%)
📋 Classes: ['Santa'] | NC: 1
Out[3]:
{'train': 700, 'valid': 155, 'test': 95}
In [18]:
# confirm the images and labels folders exist
!ls -F Santa-10/train/
images/  labels/
In [14]:
# Check images
!ls -F Santa-10/test/images | head -n 5
0-38862400_1671944344_santa_jpg.rf.55c6e9ce609b22e4451a311482828cd1.jpg
102_Santa_jpg.rf.8f184416dd01520cdc916a754e3001e6.jpg
146_Santa_jpg.rf.194951c92edbafce3d8144c398a907f5.jpg
156_Santa_jpg.rf.6ae01fa7adaefdd768712c55de29f1d5.jpg
174_Santa_jpg.rf.95ae299d0f0ef1d76006a4863c8e4b3e.jpg
In [15]:
# Check their annotations - same name, but the file ending is .txt
!ls -F Santa-10/test/labels/ | head -n 5
0-38862400_1671944344_santa_jpg.rf.55c6e9ce609b22e4451a311482828cd1.txt
102_Santa_jpg.rf.8f184416dd01520cdc916a754e3001e6.txt
146_Santa_jpg.rf.194951c92edbafce3d8144c398a907f5.txt
156_Santa_jpg.rf.6ae01fa7adaefdd768712c55de29f1d5.txt
174_Santa_jpg.rf.95ae299d0f0ef1d76006a4863c8e4b3e.txt
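The image–label pairing can also be checked programmatically rather than by eye. A minimal helper sketch (the `images/`/`labels/` split layout is the one shown above; a missing label file is legal in YOLO format and simply means a background-only image):

```python
import os

def check_image_label_pairing(split_dir):
    """Return (image stems without a label file, label stems without an image).
    The first set is expected for background-only images; the second
    usually indicates a broken export."""
    def stems(sub):
        return {os.path.splitext(f)[0]
                for f in os.listdir(os.path.join(split_dir, sub))}
    images, labels = stems('images'), stems('labels')
    return sorted(images - labels), sorted(labels - images)

# e.g. unlabeled, orphaned = check_image_label_pairing('Santa-10/test')
```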

Visualize images with bounding boxes¶

In [19]:
DATASET_PATH = './Santa-10'

Bounding boxes¶

In [ ]:
!pip install supervision -q
import cv2
import supervision as sv
import glob
import random
import os
import numpy as np
import matplotlib.pyplot as plt

CLASS_MAP = {0: 'Santa'}

def load_yolo_detections(label_path, W, H):
    detections = None

    if os.path.exists(label_path):
        boxes = []
        class_ids = []

        with open(label_path, 'r') as f:
            for line in f:
                parts = line.split()
                if len(parts) < 5:
                    continue

                c, x, y, w, h = map(float, parts[:5])  # ignore any trailing fields

                x1 = int((x - w/2) * W)
                y1 = int((y - h/2) * H)
                x2 = int((x + w/2) * W)
                y2 = int((y + h/2) * H)

                boxes.append([x1, y1, x2, y2])
                class_ids.append(int(c))

        if boxes:
            detections = sv.Detections(
                xyxy=np.array(boxes),
                class_id=np.array(class_ids)
            )

    return detections

def visualize_images(image_paths, split_name):
    if not image_paths:
        print(f'No images found in {split_name}.')
        return

    sample = random.sample(image_paths, min(10, len(image_paths)))
    fig, axes = plt.subplots(5, 2, figsize=(16, 30))
    fig.suptitle(f"{split_name.capitalize()} Split - {len(sample)} Images", fontsize=16)

    box_annotator = sv.BoxAnnotator()
    label_annotator = sv.LabelAnnotator()

    for i, path in enumerate(sample):
        img = cv2.imread(path)
        if img is None:
            continue

        H, W, _ = img.shape
        label_path = path.replace('images', 'labels').rsplit('.', 1)[0] + '.txt'
        detections = load_yolo_detections(label_path, W, H)

        if detections:
            img = box_annotator.annotate(scene=img, detections=detections)
            labels = [CLASS_MAP.get(int(cid), f'Class {cid}') for cid in detections.class_id]
            img = label_annotator.annotate(scene=img, detections=detections, labels=labels)

        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        image_id = os.path.splitext(os.path.basename(path))[0][:30]  # Truncate to 30 chars
        count = len(detections) if detections else 0

        row, col = i // 2, i % 2
        axes[row, col].imshow(img_rgb)
        axes[row, col].set_title(f'{image_id} ({W}x{H}) - {count} boxes', fontsize=10)
        axes[row, col].axis('off')

    plt.tight_layout()
    plt.subplots_adjust(top=0.96)
    plt.show()


for split in ['train', 'test', 'valid']:
    image_paths = [p for p in glob.glob(f'{DATASET_PATH}/{split}/images/*')
                   if p.lower().endswith(('.jpg', '.png'))]
    visualize_images(image_paths, split)
    print('---')
Output hidden; open in https://colab.research.google.com to view.

TODO¶

  • Comment on whether the images provide enough variety to train a robust model.
  • Fix labels where needed

Desired result: the bounding boxes are accurately placed and the associated class label (Santa) is correct. We checked for images with missing annotations (false negatives) and mislabeled objects.
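One way to make that check systematic is to scan the label files directly. A small helper sketch (the directory layout is assumed as above; empty label files are legal background-only images, but worth auditing by eye for missed annotations):

```python
import glob
import os

def find_suspect_labels(label_dir):
    """Flag label files that are empty (background-only, possibly a missed
    annotation) or that contain out-of-range YOLO coordinates (all four
    box values must be normalized to [0, 1])."""
    empty, out_of_range = [], []
    for path in sorted(glob.glob(os.path.join(label_dir, '*.txt'))):
        with open(path) as f:
            rows = [line.split() for line in f if line.strip()]
        if not rows:
            empty.append(path)
            continue
        if any(not 0.0 <= float(v) <= 1.0 for row in rows for v in row[1:5]):
            out_of_range.append(path)
    return empty, out_of_range

# e.g. empty, bad = find_suspect_labels('Santa-10/train/labels')
```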

Annotation heatmap¶

Shows where most of the annotations fall. The color gradient encodes the number of overlapping annotations per pixel.


TODO: resize all images (together with their bboxes) to 500x500 before building the heatmap, i.e. build it after plotting the distribution of image sizes in px and visualizing image examples and Canny edges¶

In [ ]:
import os
import cv2
import glob
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# --- PATHS ---
DATASET_PATH = 'Santa-8/train'  # earlier dataset version (all images 500x500), not the Santa-10 download above
IMAGE_DIR = os.path.join(DATASET_PATH, 'images')
LABEL_DIR = os.path.join(DATASET_PATH, 'labels')

# --- LOAD ALL IMAGES AND LABELS ---
image_paths = glob.glob(os.path.join(IMAGE_DIR, '*.jpg'))
print(f"Found {len(image_paths)} images\n")

# --- CALCULATE AVERAGE IMAGE DIMENSIONS ---
heights, widths = [], []

for img_path in image_paths:
    img = cv2.imread(img_path)
    if img is not None:
        h, w, _ = img.shape
        heights.append(h)
        widths.append(w)

H = int(np.mean(heights))
W = int(np.mean(widths))

print(f"Average image dimensions: {W}x{H}")
print(f"Height range: {min(heights)} - {max(heights)}")
print(f"Width range: {min(widths)} - {max(widths)}\n")

# Create heatmap
heatmap = np.zeros((H, W), dtype=np.float32)

# --- ACCUMULATE BBOXES INTO HEATMAP ---
bbox_count = 0

for img_path in image_paths:
    img = cv2.imread(img_path)
    if img is None:
        continue

    img_h, img_w, _ = img.shape

    # Get corresponding label file
    label_path = img_path.replace('images', 'labels').rsplit('.', 1)[0] + '.txt'

    if not os.path.exists(label_path):
        continue

    # Read bboxes
    with open(label_path, 'r') as f:
        for line in f:
            parts = line.split()
            if len(parts) < 5:
                continue

            class_id, x, y, w, h = map(float, parts[:5])  # ignore any trailing fields

            # Convert normalized YOLO coords to pixel coords (for original image)
            x1 = int((x - w/2) * img_w)
            y1 = int((y - h/2) * img_h)
            x2 = int((x + w/2) * img_w)
            y2 = int((y + h/2) * img_h)

            # Scale bbox to average dimensions
            x1_scaled = int(x1 * W / img_w)
            y1_scaled = int(y1 * H / img_h)
            x2_scaled = int(x2 * W / img_w)
            y2_scaled = int(y2 * H / img_h)

            # Clip to heatmap bounds
            x1_scaled = max(0, x1_scaled)
            y1_scaled = max(0, y1_scaled)
            x2_scaled = min(W, x2_scaled)
            y2_scaled = min(H, y2_scaled)

            # Add bbox area to heatmap
            heatmap[y1_scaled:y2_scaled, x1_scaled:x2_scaled] += 1
            bbox_count += 1

print(f"Total bboxes: {bbox_count}")
print(f"Heatmap range: {heatmap.min():.1f} - {heatmap.max():.1f}\n")

# --- CALCULATE STATISTICS ---
q1 = np.quantile(heatmap, 0.25)
median = np.quantile(heatmap, 0.50)
q3 = np.quantile(heatmap, 0.75)
min_val = heatmap.min()
max_val = heatmap.max()

stats_text = f"""# of Annotations Per Grid
Min: {int(min_val)}
Q1: {int(q1)}
Median: {int(median)}
Q3: {int(q3)}
Max: {int(max_val)}"""

# --- VISUALIZE HEATMAP ---
fig, axes = plt.subplots(1, 2, figsize=(15, 6))
fig.suptitle(f'Annotation Density Heatmap - Train Split ({bbox_count} bboxes)', fontsize=14)

# Heatmap
im = axes[0].imshow(heatmap, cmap='hot', interpolation='bilinear')
axes[0].set_title('Bbox Density (brighter = more annotations)', fontsize=12)
axes[0].set_xlabel('Width')
axes[0].set_ylabel('Height')
cbar = plt.colorbar(im, ax=axes[0], label='Count')

# Add text box with statistics
axes[0].text(0.02, 0.98, stats_text, transform=axes[0].transAxes,
            fontsize=10, verticalalignment='top', fontfamily='monospace',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

# Contour plot
axes[1].contourf(heatmap, levels=15, cmap='hot')
axes[1].set_title('Annotation Contours', fontsize=12)
axes[1].set_xlabel('Width')
axes[1].set_ylabel('Height')

plt.tight_layout()
plt.show()

# --- STATISTICS ---
print("\nπŸ“Š Annotation Statistics:")
print(stats_text)
print(f"\nImage dimensions: {W}x{H}")
print(f"Mean annotations per pixel: {heatmap.mean():.4f}")
print(f"Coverage: {(heatmap > 0).sum() / (H * W) * 100:.1f}% of image")
Found 590 images

Average image dimensions: 500x500
Height range: 500 - 500
Width range: 500 - 500

Total bboxes: 343
Heatmap range: 4.0 - 272.0

[figure: annotation density heatmap and contour plot]
📊 Annotation Statistics:
# of Annotations Per Grid
Min: 4
Q1: 91
Median: 156
Q3: 220
Max: 272

Image dimensions: 500x500
Mean annotations per pixel: 153.7937
Coverage: 100.0% of image

Distribution of labels¶

In [ ]:
IMAGE_EXT = '*.jpg'
LABEL_EXT = '*.txt'
TARGET_CLASS = 0

# Get all images
image_paths = glob.glob(os.path.join(DATASET_PATH, 'images', IMAGE_EXT))
label_paths = [os.path.join(DATASET_PATH, 'labels', os.path.basename(p).rsplit('.',1)[0] + '.txt')
               for p in image_paths]

total_images = len(image_paths)
santa_object_count = 0
images_with_santa = 0
images_without_santa = 0

# Iterate over all labels
for lbl_path in label_paths:
    if os.path.exists(lbl_path):
        with open(lbl_path, 'r') as f:
            lines = f.read().splitlines()
        class_ids = [int(line.split()[0]) for line in lines if line.strip()]
        santa_count_in_image = class_ids.count(TARGET_CLASS)  # class 0 = Santa
        santa_object_count += santa_count_in_image
        if santa_count_in_image > 0:
            images_with_santa += 1
        else:
            images_without_santa += 1
    else:
        images_without_santa += 1  # No label file = background only

# Results
print(f'Total images: {total_images}')
print(f'Total Santa objects (positive examples): {santa_object_count}')
print(f'Images containing Santa: {images_with_santa}')
print(f'Images without Santa: {images_without_santa}')
print(f'Percentage of images with Santa: {images_with_santa / total_images * 100:.2f}%')
print(f'Percentage of background-only images: {images_without_santa / total_images * 100:.2f}%')
Total images: 590
Total Santa objects (positive examples): 366
Images containing Santa: 335
Images without Santa: 255
Percentage of images with Santa: 56.78%
Percentage of background-only images: 43.22%

The dataset contains 366 bounding boxes labeled as Santa, spread across 335 images that contain at least one Santa; some images therefore contain more than one Santa.

In [ ]:
# Images with more than one bounding box
multi_bbox_images = []
multi_bbox_labels = []

# Find images with more than one bounding box
for img_path, lbl_path in zip(image_paths, label_paths):
    if os.path.exists(lbl_path):
        with open(lbl_path, 'r') as f:
            lines = f.read().splitlines()
        bbox_count = len([line for line in lines if line.strip()])

        if bbox_count > 1:
            multi_bbox_images.append(img_path)
            multi_bbox_labels.append(lbl_path)

print(f"Total images: {len(image_paths)}")
print(f"Images with >1 bounding box: {len(multi_bbox_images)}\n")
Total images: 590
Images with >1 bounding box: 17

In [ ]:
# Visualize multi-bbox images
sample_size = min(10, len(multi_bbox_images))
sample_indices = np.random.choice(len(multi_bbox_images), sample_size, replace=False)

fig, axes = plt.subplots(5, 2, figsize=(14, 18))
fig.suptitle(f"Images with Multiple Bounding Boxes (showing {sample_size})", fontsize=16)

for idx, sample_idx in enumerate(sample_indices):
    img_path = multi_bbox_images[sample_idx]
    lbl_path = multi_bbox_labels[sample_idx]

    # Load image
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    H, W, _ = img.shape

    # Read bounding boxes
    with open(lbl_path, 'r') as f:
        lines = f.read().splitlines()

    bbox_count = len([line for line in lines if line.strip()])

    # Draw bounding boxes
    for line in lines:
        if not line.strip():
            continue
        parts = line.split()
        class_id = int(parts[0])
        x, y, w, h = map(float, parts[1:5])

        # Convert normalized YOLO coords to pixel coords
        x1 = int((x - w/2) * W)
        y1 = int((y - h/2) * H)
        x2 = int((x + w/2) * W)
        y2 = int((y + h/2) * H)

        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(img, f'Class {class_id}', (x1, y1-10),
                   cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

    # Plot
    row, col = idx // 2, idx % 2
    axes[row, col].imshow(img)
    axes[row, col].set_title(f"{os.path.basename(img_path)}\n({bbox_count} boxes)", fontsize=10)
    axes[row, col].axis('off')

plt.tight_layout()
plt.show()

Image characteristics: edges, color scheme¶

In [ ]:
# Do the images have clear edges?
# Select a random image

# Find all images in the train split and randomly select 3
all_train_images = glob.glob(os.path.join(DATASET_PATH, 'images', IMAGE_EXT))

# Filter images that contain only Santa class ---
santa_images = []

# Loop through all images
for img_path in all_train_images:
    # Find the corresponding label file (replace 'images' with 'labels')
    label_path = img_path.replace('images', 'labels').rsplit('.', 1)[0] + '.txt'

    if os.path.exists(label_path):
        # Read the label file
        with open(label_path, 'r') as f:
            labels = [int(line.split()[0]) for line in f.readlines()]

            # Check if the labels are not empty and all labels are "Santa" (class_id == TARGET_CLASS)
            if labels and all(label == TARGET_CLASS for label in labels):  # All labels must be Santa
                santa_images.append(img_path)

# Print the filtered images that only contain "Santa" class
print(f"Found {len(santa_images)} images containing only Santa class.")

# --- Randomly select 3 images with Santa ---
sample_images = random.sample(santa_images, 3)

for img_path in sample_images:
    print(f"Image: {os.path.basename(img_path)}")

    # Load image
    img = cv2.imread(img_path)

    # Canny Edge Detection
    edges = cv2.Canny(img, 100, 200)

    # Convert BGR -> RGB for display
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Plot original and edges
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    axes[0].imshow(img_rgb)
    axes[0].set_title('Original Image')
    axes[0].axis('off')

    axes[1].imshow(edges, cmap='gray')
    axes[1].set_title('Canny Edges (100, 200)')
    axes[1].axis('off')

    plt.show()
Found 312 images containing only Santa class.
Image: 588_Santa_jpg.rf.4db1fd8a1ffc7a45d5463ca951572759.jpg
Image: christmas-1903109_1280_jpg.rf.d2fe55a682d00dd853206b6c80efbfbb.jpg
Image: 384_Santa_jpg.rf.7bc450fb07597ade4ab2dae811d8d5d5.jpg
In [ ]:
# Find all images and filter for Santa-only
all_train_images = glob.glob(os.path.join(DATASET_PATH, 'images', IMAGE_EXT))
santa_images = []

for img_path in all_train_images:
    label_path = img_path.replace('images', 'labels').rsplit('.', 1)[0] + '.txt'

    if os.path.exists(label_path):
        with open(label_path, 'r') as f:
            labels = [int(line.split()[0]) for line in f.readlines()]

            if labels and all(label == TARGET_CLASS for label in labels):
                santa_images.append(img_path)

print(f"Found {len(santa_images)} images containing only Santa class.\n")

# Select 5 random images and apply Canny edge detection
sample_images = random.sample(santa_images, 5)

# Create figure with 5 rows (one per image) and 2 columns (original, Canny)
fig, axes = plt.subplots(5, 2, figsize=(8,10))
fig.suptitle('Canny Edge Detection', fontsize=16)

for row, img_path in enumerate(sample_images):
    img = cv2.imread(img_path)
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Canny Edge Detection
    edges = cv2.Canny(img, 100, 200)

    # Original
    axes[row, 0].imshow(img_rgb)
    axes[row, 0].set_title(f'{os.path.basename(img_path)}', fontsize=10)
    axes[row, 0].axis('off')

    # Canny
    axes[row, 1].imshow(edges, cmap='gray')
    axes[row, 1].set_title(f'Canny (100, 200)', fontsize=10)
    axes[row, 1].axis('off')

plt.tight_layout()
plt.show()
Found 312 images containing only Santa class.


For Santa Claus images, we want thresholds high enough to capture the strong boundaries of the red suit, hat, and belt (the mid-level features), yet low enough to pick up the edges of the white beard and fur trim without admitting excessive background noise.

If objects have blurry, weak edges we can expect potential training issues, while high-contrast objects with strong edges are easier for the model to learn.

In [ ]:
import os
import glob
import numpy as np
import matplotlib.pyplot as plt

# --- PATHS ---
IMAGE_DIR = os.path.join(DATASET_PATH, 'images')
LABEL_DIR = os.path.join(DATASET_PATH, 'labels')

# --- COUNT OBJECTS PER IMAGE ---
image_paths = glob.glob(os.path.join(IMAGE_DIR, '*.jpg'))
object_counts = []
images_without_objects = 0

for img_path in image_paths:
    label_path = img_path.replace('images', 'labels').rsplit('.', 1)[0] + '.txt'

    count = 0
    if os.path.exists(label_path):
        with open(label_path, 'r') as f:
            count = len([line for line in f if line.strip()])

    object_counts.append(count)
    if count == 0:
        images_without_objects += 1

object_counts = np.array(object_counts)

# --- STATISTICS ---
stats = {
    'Total Images': len(image_paths),
    'Images without objects': images_without_objects,
    'Images with objects': len(image_paths) - images_without_objects,
    'Total objects': int(object_counts.sum()),
    'Mean objects/image': f"{object_counts.mean():.2f}",
    'Median objects/image': int(np.median(object_counts)),
    'Min objects': int(object_counts.min()),
    'Max objects': int(object_counts.max()),
    'Std Dev': f"{object_counts.std():.2f}"
}

print("📊 Object Count Statistics:")
print("-" * 40)
for key, val in stats.items():
    print(f"{key:.<30} {val}")
print()

# --- HISTOGRAM ---
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle('Object Distribution Analysis - Train Split', fontsize=16, fontweight='bold')

# 1. Histogram of object counts
ax = axes[0, 0]
bins = range(0, int(object_counts.max()) + 2)
ax.hist(object_counts, bins=bins, color='steelblue', edgecolor='black', alpha=0.7)
ax.axvline(object_counts.mean(), color='red', linestyle='--', linewidth=2, label=f'Mean: {object_counts.mean():.2f}')
ax.axvline(np.median(object_counts), color='green', linestyle='--', linewidth=2, label=f'Median: {int(np.median(object_counts))}')
ax.set_xlabel('Objects per Image', fontsize=11)
ax.set_ylabel('Number of Images', fontsize=11)
ax.set_title('Histogram: Objects per Image', fontsize=12, fontweight='bold')
ax.legend()
ax.grid(alpha=0.3)

# 2. Box plot
ax = axes[0, 1]
bp = ax.boxplot(object_counts, vert=True, patch_artist=True)
bp['boxes'][0].set_facecolor('lightblue')
ax.set_ylabel('Objects per Image', fontsize=11)
ax.set_title('Box Plot: Object Distribution', fontsize=12, fontweight='bold')
ax.grid(alpha=0.3, axis='y')

# Add statistics text on box plot
textstr = f"Q1: {int(np.quantile(object_counts, 0.25))}\nMedian: {int(np.median(object_counts))}\nQ3: {int(np.quantile(object_counts, 0.75))}"
ax.text(1.15, object_counts.mean(), textstr, fontsize=10,
        bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

# 3. Cumulative distribution
ax = axes[1, 0]
sorted_counts = np.sort(object_counts)
cumulative = np.arange(1, len(sorted_counts) + 1) / len(sorted_counts) * 100
ax.plot(sorted_counts, cumulative, marker='o', linewidth=2, markersize=4, color='steelblue')
ax.fill_between(sorted_counts, cumulative, alpha=0.3, color='steelblue')
ax.set_xlabel('Objects per Image', fontsize=11)
ax.set_ylabel('Cumulative % of Images', fontsize=11)
ax.set_title('Cumulative Distribution', fontsize=12, fontweight='bold')
ax.grid(alpha=0.3)

# 4. Class distribution pie chart (if multiple classes)
ax = axes[1, 1]
class_counts = [0] * 10  # Support up to 10 classes

for img_path in image_paths:
    label_path = img_path.replace('images', 'labels').rsplit('.', 1)[0] + '.txt'
    if os.path.exists(label_path):
        with open(label_path, 'r') as f:
            for line in f:
                parts = line.split()
                if parts:
                    class_id = int(parts[0])
                    if class_id < len(class_counts):
                        class_counts[class_id] += 1

present_classes = [(i, c) for i, c in enumerate(class_counts) if c > 0]
class_labels = [f'Class {i}' for i, _ in present_classes]
class_counts = [c for _, c in present_classes]

if len(class_counts) > 1:
    colors = plt.cm.Set3(range(len(class_counts)))
    ax.pie(class_counts, labels=class_labels, autopct='%1.1f%%', colors=colors, startangle=90)
    ax.set_title('Class Distribution', fontsize=12, fontweight='bold')
else:
    ax.text(0.5, 0.5, f'Single Class Dataset\nTotal Objects: {sum(class_counts)}',
            ha='center', va='center', fontsize=14, transform=ax.transAxes)
    ax.set_title('Class Distribution', fontsize=12, fontweight='bold')
    ax.axis('off')

plt.tight_layout()
plt.show()

# --- DETAILED BREAKDOWN ---
print("\n📋 Detailed Breakdown:")
print("-" * 40)
unique, counts = np.unique(object_counts, return_counts=True)
for obj_count, num_images in zip(unique, counts):
    pct = (num_images / len(image_paths)) * 100
    print(f"{int(obj_count)} object(s):  {int(num_images):>3} images ({pct:>5.1f}%)")
📊 Object Count Statistics:
----------------------------------------
Total Images.................. 590
Images without objects........ 278
Images with objects........... 312
Total objects................. 343
Mean objects/image............ 0.58
Median objects/image.......... 1
Min objects................... 0
Max objects................... 6
Std Dev....................... 0.66

📋 Detailed Breakdown:
----------------------------------------
0 object(s):  278 images ( 47.1%)
1 object(s):  295 images ( 50.0%)
2 object(s):    9 images (  1.5%)
3 object(s):    5 images (  0.8%)
4 object(s):    1 images (  0.2%)
5 object(s):    1 images (  0.2%)
6 object(s):    1 images (  0.2%)

What we visualized:

  1. Histogram – distribution of objects per image (red = mean, green = median)
  2. Box plot – quartiles, outliers, and spread
  3. Cumulative distribution – what % of images have ≤ X objects
  4. Class distribution pie – breakdown by class (single-class or multi-class)

What to look for:

  • Balanced – similar counts across images (good for training)
  • Skewed – most images have one object, few have many (may need data augmentation)
  • Empty images – images with 0 objects (check whether these are intentional negatives)
  • Outliers – images with unusually many objects (harder to train on)
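The skew visible above (roughly half the images are background-only, and positives almost always hold a single box) can be summarized numerically before deciding on rebalancing; a small sketch, where `imbalance_summary` is a hypothetical helper rather than notebook code:

```python
import numpy as np

def imbalance_summary(object_counts):
    """Foreground/background balance from per-image object counts."""
    counts = np.asarray(object_counts)
    empty = int((counts == 0).sum())
    return {
        'background_fraction': empty / len(counts),
        'positive_fraction': 1 - empty / len(counts),
        'mean_objects_per_positive_image': float(counts[counts > 0].mean()),
    }

# Counts mirroring the breakdown above: 278 empty images, 295 with one box, ...
counts = [0] * 278 + [1] * 295 + [2] * 9 + [3] * 5 + [4, 5, 6]
summary = imbalance_summary(counts)
```

Such a summary can inform whether to keep the background images as explicit negatives or to rebalance, e.g. by augmenting the rarer multi-object images.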
In [ ]:
%pwd
Out[ ]:
'/content/drive/My Drive/FHNW/HS_25/DLBS/minichallenge_hs25_object_detection'
In [ ]:
"""
YOLO Dataset Exploratory Data Analysis Tool
Comprehensive visual analysis of YOLO object detection datasets
"""

import os
import yaml
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors  # used by create_annotation_heatmap
import seaborn as sns
from PIL import Image
from pathlib import Path
import cv2
from collections import defaultdict

class YOLODatasetEDA:
    """
    Comprehensive EDA for YOLO datasets with visualization.
    """

    def __init__(self, yaml_path, dataset_root=None):
        """
        Initialize EDA tool.

        Parameters:
        -----------
        yaml_path : str
            Path to data.yaml file
        dataset_root : str, optional
            Root directory of dataset (if yaml paths are relative)
        """
        self.yaml_path = yaml_path

        # Load YAML
        with open(yaml_path, 'r') as f:
            self.config = yaml.safe_load(f)

        # Set dataset root
        if dataset_root is None:
            dataset_root = os.path.dirname(yaml_path)
        self.dataset_root = dataset_root

        # Get class names
        self.class_names = self.config['names']
        self.num_classes = self.config['nc']

        print(f"πŸ“ Dataset loaded: {self.num_classes} class(es)")
        print(f"   Classes: {self.class_names}")

    def _get_split_paths(self, split):
        """Get image and label paths for a split."""
        # Handle relative paths
        img_path = self.config[split]
        if img_path.startswith('../'):
            # str.lstrip strips a set of characters, not a prefix; use removeprefix
            img_path = os.path.join(self.dataset_root, img_path.removeprefix('../'))

        # Get labels path (replace /images with /labels)
        label_path = img_path.replace('/images', '/labels')

        return img_path, label_path

    def analyze_dataset_splits(self):
        """
        Analyze dataset splits: train, val, test.
        Returns statistics for each split.
        """
        splits_data = {}

        for split in ['train', 'val', 'test']:
            if split not in self.config:
                continue

            img_path, label_path = self._get_split_paths(split)

            if not os.path.exists(img_path):
                print(f"⚠️  Warning: {split} images not found at {img_path}")
                continue

            # Count images
            image_files = [f for f in os.listdir(img_path)
                          if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
            num_images = len(image_files)

            # Count annotations
            total_annotations = 0
            class_counts = defaultdict(int)
            images_with_annotations = 0

            if os.path.exists(label_path):
                for img_file in image_files:
                    label_file = os.path.splitext(img_file)[0] + '.txt'
                    label_file_path = os.path.join(label_path, label_file)

                    if os.path.exists(label_file_path):
                        with open(label_file_path, 'r') as f:
                            lines = f.readlines()
                            if lines:
                                images_with_annotations += 1
                                for line in lines:
                                    if line.strip():
                                        parts = line.strip().split()
                                        if parts:
                                            cls_id = int(parts[0])
                                            class_counts[cls_id] += 1
                                            total_annotations += 1

            splits_data[split] = {
                'num_images': num_images,
                'total_annotations': total_annotations,
                'images_with_annotations': images_with_annotations,
                'class_counts': dict(class_counts),
                'avg_annotations_per_image': total_annotations / num_images if num_images > 0 else 0
            }

            print(f"\n{split.upper()} split:")
            print(f"  Images: {num_images}")
            print(f"  Total annotations: {total_annotations}")
            print(f"  Images with annotations: {images_with_annotations}")
            print(f"  Avg annotations/image: {splits_data[split]['avg_annotations_per_image']:.2f}")

        return splits_data

    def plot_split_statistics(self, splits_data):
        """
        Plot 1: Dataset split bar chart with class annotations.
        """
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

        splits = list(splits_data.keys())

        # Plot 1: Number of images per split
        num_images = [splits_data[s]['num_images'] for s in splits]
        colors = ['#3498db', '#e74c3c', '#2ecc71']

        bars = ax1.bar(splits, num_images, color=colors[:len(splits)],
                      edgecolor='black', alpha=0.7)
        ax1.set_ylabel('Number of Images', fontsize=12)
        ax1.set_title('Dataset Split Distribution', fontsize=14, fontweight='bold')
        ax1.grid(True, alpha=0.3, axis='y')

        # Add value labels on bars
        for bar, count in zip(bars, num_images):
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width()/2., height,
                    f'{int(count)}',
                    ha='center', va='bottom', fontsize=12, fontweight='bold')

        # Plot 2: Annotations per split (stacked by class)
        annotations_by_class = {}
        for cls_id in range(self.num_classes):
            annotations_by_class[cls_id] = [
                splits_data[s]['class_counts'].get(cls_id, 0) for s in splits
            ]

        bottom = np.zeros(len(splits))
        colors_classes = plt.cm.Set3(np.linspace(0, 1, self.num_classes))

        for cls_id in range(self.num_classes):
            class_name = self.class_names[cls_id]
            counts = annotations_by_class[cls_id]
            ax2.bar(splits, counts, bottom=bottom, label=class_name,
                   color=colors_classes[cls_id], edgecolor='black', alpha=0.8)

            # Add value labels
            for i, (split_name, count) in enumerate(zip(splits, counts)):
                if count > 0:
                    ax2.text(i, bottom[i] + count/2, str(count),
                            ha='center', va='center', fontsize=10, fontweight='bold')

            bottom += counts

        ax2.set_ylabel('Number of Annotations', fontsize=12)
        ax2.set_title('Annotations per Split (by Class)', fontsize=14, fontweight='bold')
        ax2.legend()
        ax2.grid(True, alpha=0.3, axis='y')

        plt.tight_layout()
        plt.show()

    def analyze_image_dimensions(self):
        """
        Analyze image dimensions across all splits.
        """
        widths = []
        heights = []
        aspect_ratios = []

        for split in ['train', 'val', 'test']:
            if split not in self.config:
                continue

            img_path, _ = self._get_split_paths(split)

            if not os.path.exists(img_path):
                continue

            image_files = [f for f in os.listdir(img_path)
                          if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

            # Sample up to 100 images for speed
            sampled_files = np.random.choice(image_files,
                                           min(100, len(image_files)),
                                           replace=False)

            for img_file in sampled_files:
                img_full_path = os.path.join(img_path, img_file)
                try:
                    with Image.open(img_full_path) as img:
                        w, h = img.size
                        widths.append(w)
                        heights.append(h)
                        aspect_ratios.append(w / h)
                except OSError:
                    continue

        return np.array(widths), np.array(heights), np.array(aspect_ratios)

    def plot_image_dimensions(self, widths, heights, aspect_ratios):
        """
        Plot 2: Image dimensions and aspect ratios.
        """
        fig, axes = plt.subplots(1, 3, figsize=(18, 5))

        # Scatter plot: width vs height
        ax = axes[0]
        scatter = ax.scatter(widths, heights, alpha=0.5, s=50, c=aspect_ratios,
                           cmap='viridis', edgecolors='black', linewidth=0.5)
        ax.set_xlabel('Width (pixels)', fontsize=12)
        ax.set_ylabel('Height (pixels)', fontsize=12)
        ax.set_title('Image Dimensions Distribution', fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3)

        # Add average lines
        avg_w, avg_h = widths.mean(), heights.mean()
        ax.axvline(avg_w, color='red', linestyle='--', linewidth=2,
                  label=f'Avg W: {avg_w:.0f}')
        ax.axhline(avg_h, color='blue', linestyle='--', linewidth=2,
                  label=f'Avg H: {avg_h:.0f}')
        ax.legend()

        # Colorbar
        cbar = plt.colorbar(scatter, ax=ax)
        cbar.set_label('Aspect Ratio (W/H)', fontsize=10)

        # Histogram: widths
        ax = axes[1]
        ax.hist(widths, bins=30, color='skyblue', edgecolor='black', alpha=0.7)
        ax.axvline(avg_w, color='red', linestyle='--', linewidth=2,
                  label=f'Mean: {avg_w:.0f}')
        ax.axvline(np.median(widths), color='green', linestyle='--', linewidth=2,
                  label=f'Median: {np.median(widths):.0f}')
        ax.set_xlabel('Width (pixels)', fontsize=12)
        ax.set_ylabel('Frequency', fontsize=12)
        ax.set_title('Width Distribution', fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

        # Histogram: heights
        ax = axes[2]
        ax.hist(heights, bins=30, color='lightcoral', edgecolor='black', alpha=0.7)
        ax.axvline(avg_h, color='blue', linestyle='--', linewidth=2,
                  label=f'Mean: {avg_h:.0f}')
        ax.axvline(np.median(heights), color='green', linestyle='--', linewidth=2,
                  label=f'Median: {np.median(heights):.0f}')
        ax.set_xlabel('Height (pixels)', fontsize=12)
        ax.set_ylabel('Frequency', fontsize=12)
        ax.set_title('Height Distribution', fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

        plt.tight_layout()
        plt.show()

        # Print statistics
        print(f"\nπŸ“ IMAGE DIMENSION STATISTICS:")
        print(f"   Width:  mean={widths.mean():.0f}, median={np.median(widths):.0f}, "
              f"std={widths.std():.0f}")
        print(f"   Height: mean={heights.mean():.0f}, median={np.median(heights):.0f}, "
              f"std={heights.std():.0f}")
        print(f"   Aspect Ratio: mean={aspect_ratios.mean():.2f}, "
              f"median={np.median(aspect_ratios):.2f}")


    def create_annotation_heatmap(self, grid_size=20):
        """
        Plot 3: Heatmap showing where annotations are located in images.
        Updated to use a blue-red colormap where:
        - Blue = lower frequency,
        - Orange = intermediate frequency,
        - Red = higher frequency.
        """
        # Create grid for heatmap (normalized coordinates 0-1)
        heatmap = np.zeros((grid_size, grid_size))

        for split in ['train', 'val', 'test']:
            if split not in self.config:
                continue

            img_path, label_path = self._get_split_paths(split)

            if not os.path.exists(label_path):
                continue

            label_files = [f for f in os.listdir(label_path) if f.endswith('.txt')]

            for label_file in label_files:
                label_file_path = os.path.join(label_path, label_file)

                with open(label_file_path, 'r') as f:
                    for line in f:
                        parts = line.strip().split()
                        if len(parts) >= 5:
                            # YOLO format: class x_center y_center width height (normalized)
                            x_center = float(parts[1])
                            y_center = float(parts[2])

                            # Map to grid
                            grid_x = int(x_center * grid_size)
                            grid_y = int(y_center * grid_size)

                            # Clamp to valid range
                            grid_x = max(0, min(grid_size - 1, grid_x))
                            grid_y = max(0, min(grid_size - 1, grid_y))

                            heatmap[grid_y, grid_x] += 1

        # Define a custom colormap: Blue to Red, with intermediate Orange
        cmap = mcolors.LinearSegmentedColormap.from_list(
            "blue_red", ["blue", "orange", "red"], N=256
        )

        # Plot heatmap with the updated colormap
        fig, ax = plt.subplots(figsize=(10, 8))

        im = ax.imshow(heatmap, cmap=cmap, interpolation='bilinear', origin='upper')
        ax.set_xlabel('Normalized X Position', fontsize=12)
        ax.set_ylabel('Normalized Y Position', fontsize=12)
        ax.set_title('Annotation Center Heatmap (All Splits)',
                    fontsize=14, fontweight='bold')

        # Add colorbar
        cbar = plt.colorbar(im, ax=ax)
        cbar.set_label('Annotation Density', fontsize=12)

        # Add grid
        ax.set_xticks(np.arange(0, grid_size, grid_size//5))
        ax.set_yticks(np.arange(0, grid_size, grid_size//5))
        ax.set_xticklabels([f'{x/grid_size:.1f}' for x in range(0, grid_size, grid_size//5)])
        ax.set_yticklabels([f'{y/grid_size:.1f}' for y in range(0, grid_size, grid_size//5)])
        ax.grid(True, alpha=0.3, color='white', linewidth=1)

        plt.tight_layout()
        plt.show()

        return heatmap


    def analyze_objects_per_image(self):
        """
        Analyze distribution of object counts per image.
        """
        object_counts = []

        for split in ['train', 'val', 'test']:
            if split not in self.config:
                continue

            img_path, label_path = self._get_split_paths(split)

            if not os.path.exists(img_path) or not os.path.exists(label_path):
                continue

            image_files = [f for f in os.listdir(img_path)
                          if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

            for img_file in image_files:
                label_file = os.path.splitext(img_file)[0] + '.txt'
                label_file_path = os.path.join(label_path, label_file)

                count = 0
                if os.path.exists(label_file_path):
                    with open(label_file_path, 'r') as f:
                        count = len([line for line in f if line.strip()])

                object_counts.append(count)

        return np.array(object_counts)

    def plot_objects_per_image(self, object_counts):
        """
        Plot 4: Histogram of object count per image.
        """
        fig, ax = plt.subplots(figsize=(12, 6))

        max_count = object_counts.max()
        bins = range(0, max_count + 2)

        ax.hist(object_counts, bins=bins, edgecolor='black',
               color='mediumseagreen', alpha=0.7)

        # Statistics
        mean_count = object_counts.mean()
        median_count = np.median(object_counts)

        ax.axvline(mean_count, color='red', linestyle='--', linewidth=2,
                  label=f'Mean: {mean_count:.2f}')
        ax.axvline(median_count, color='blue', linestyle='--', linewidth=2,
                  label=f'Median: {median_count:.0f}')

        ax.set_xlabel('Number of Objects per Image', fontsize=12)
        ax.set_ylabel('Frequency', fontsize=12)
        ax.set_title('Distribution of Object Count per Image',
                    fontsize=14, fontweight='bold')
        ax.legend(fontsize=12)
        ax.grid(True, alpha=0.3, axis='y')

        plt.tight_layout()
        plt.show()

        # Print statistics
        print(f"\n📦 OBJECTS PER IMAGE STATISTICS:")
        print(f"   Mean: {mean_count:.2f}")
        print(f"   Median: {median_count:.0f}")
        print(f"   Max: {max_count}")
        print(f"   Images with 0 objects: {np.sum(object_counts == 0)}")
        print(f"   Images with 1 object: {np.sum(object_counts == 1)}")
        print(f"   Images with 2+ objects: {np.sum(object_counts >= 2)}")

    def visualize_sample_images(self, num_samples=6, split='train'):
        """
        Plot 5: Sample images with bounding boxes.
        """
        img_path, label_path = self._get_split_paths(split)

        if not os.path.exists(img_path):
            print(f"⚠️  {split} images not found")
            return

        image_files = [f for f in os.listdir(img_path)
                      if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

        # Sample random images
        sampled_files = np.random.choice(image_files,
                                        min(num_samples, len(image_files)),
                                        replace=False)

        # Create grid
        rows = 2
        cols = 3
        fig, axes = plt.subplots(rows, cols, figsize=(18, 12))
        axes = axes.flatten()

        for idx, img_file in enumerate(sampled_files):
            if idx >= rows * cols:
                break

            # Load image
            img_full_path = os.path.join(img_path, img_file)
            img = cv2.imread(img_full_path)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            h, w = img.shape[:2]

            # Load annotations
            label_file = os.path.splitext(img_file)[0] + '.txt'
            label_file_path = os.path.join(label_path, label_file)

            if os.path.exists(label_file_path):
                with open(label_file_path, 'r') as f:
                    for line in f:
                        parts = line.strip().split()
                        if len(parts) >= 5:
                            cls_id = int(parts[0])
                            x_center = float(parts[1]) * w
                            y_center = float(parts[2]) * h
                            box_w = float(parts[3]) * w
                            box_h = float(parts[4]) * h

                            # Convert to corner coordinates
                            x1 = int(x_center - box_w / 2)
                            y1 = int(y_center - box_h / 2)
                            x2 = int(x_center + box_w / 2)
                            y2 = int(y_center + box_h / 2)

                            # Draw rectangle
                            cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)

                            # Add label
                            # class_names may be a list or an id->name dict; indexing works for both
                            class_name = str(self.class_names[cls_id])
                            cv2.putText(img, class_name, (x1, y1 - 10),
                                      cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)

            # Display
            axes[idx].imshow(img)
            axes[idx].set_title(f'{split}: {img_file}', fontsize=10)
            axes[idx].axis('off')

        # Hide unused subplots
        for idx in range(len(sampled_files), rows * cols):
            axes[idx].axis('off')

        plt.suptitle(f'Sample Images with Annotations ({split.upper()} split)',
                    fontsize=16, fontweight='bold', y=0.995)
        plt.tight_layout()
        plt.show()

    def run_full_eda(self):
        """
        Run complete EDA pipeline with all visualizations.
        """
        print("="*70)
        print("πŸ” YOLO DATASET EXPLORATORY DATA ANALYSIS")
        print("="*70)

        # 1. Dataset splits analysis
        print("\nπŸ“Š Analyzing dataset splits...")
        splits_data = self.analyze_dataset_splits()
        self.plot_split_statistics(splits_data)

        # # 2. Image dimensions
        # print("\nπŸ“ Analyzing image dimensions...")
        # widths, heights, aspect_ratios = self.analyze_image_dimensions()
        # self.plot_image_dimensions(widths, heights, aspect_ratios)

        # 3. Annotation heatmap
        print("\nπŸ—ΊοΈ  Creating annotation heatmap...")
        self.create_annotation_heatmap(grid_size=20)

        # 4. Objects per image
        print("\nπŸ“¦ Analyzing objects per image...")
        object_counts = self.analyze_objects_per_image()
        self.plot_objects_per_image(object_counts)

        # 5. Sample images
        print("\nπŸ–ΌοΈ  Visualizing sample images...")
        for split in ['train', 'val']:
            if split in self.config:
                print(f"\n   {split.upper()} samples:")
                self.visualize_sample_images(num_samples=6, split=split)

        print("\n" + "="*70)
        print("βœ… EDA COMPLETE!")
        print("="*70)


# =======================
# USAGE
# =======================

# Initialize EDA
yaml_path = '/content/drive/My Drive/FHNW/HS_25/DLBS/minichallenge_hs25_object_detection/Santa-9/data.yaml'  # Update this path
dataset_root = '/content/drive/My Drive/FHNW/HS_25/DLBS/minichallenge_hs25_object_detection/Santa-9'  # Update if needed

eda = YOLODatasetEDA(yaml_path, dataset_root)

# Run full analysis
eda.run_full_eda()

# Or run individual analyses
# splits_data = eda.analyze_dataset_splits()
# eda.plot_split_statistics(splits_data)
# widths, heights, aspect_ratios = eda.analyze_image_dimensions()
# eda.plot_image_dimensions(widths, heights, aspect_ratios)
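The center-based YOLO label format that `visualize_sample_images` decodes can be isolated into a small helper for reference. A minimal standalone sketch of the same conversion; the function name `yolo_to_corners` is illustrative and not part of the class above:

```python
def yolo_to_corners(x_center, y_center, box_w, box_h, img_w, img_h):
    """Convert a normalized YOLO box (cx, cy, w, h) to integer pixel corners (x1, y1, x2, y2)."""
    # Scale normalized coordinates to pixels
    cx, cy = x_center * img_w, y_center * img_h
    bw, bh = box_w * img_w, box_h * img_h
    # Shift from the center to top-left / bottom-right corners
    x1, y1 = int(cx - bw / 2), int(cy - bh / 2)
    x2, y2 = int(cx + bw / 2), int(cy + bh / 2)
    return x1, y1, x2, y2

# A box centered in a 640x480 image, covering half of each dimension:
print(yolo_to_corners(0.5, 0.5, 0.5, 0.5, 640, 480))  # (160, 120, 480, 360)
```

Note that corners can fall outside the image for boxes near the border, which is why the drawing code clamps implicitly via OpenCV rather than here.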
In [ ]:
"""
YOLO Dataset Exploratory Data Analysis Tool
Comprehensive visual analysis of YOLO object detection datasets
"""

import os
import yaml
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from pathlib import Path
import cv2
from collections import defaultdict

class YOLODatasetEDA:
    """
    Comprehensive EDA for YOLO datasets with visualization.
    """

    def __init__(self, yaml_path, dataset_root=None):
        """
        Initialize EDA tool.

        Parameters:
        -----------
        yaml_path : str
            Path to data.yaml file
        dataset_root : str, optional
            Root directory of dataset (if yaml paths are relative)
        """
        self.yaml_path = yaml_path

        # Load YAML
        with open(yaml_path, 'r') as f:
            self.config = yaml.safe_load(f)

        # Set dataset root
        if dataset_root is None:
            dataset_root = os.path.dirname(yaml_path)
        self.dataset_root = dataset_root

        # Get class names
        self.class_names = self.config['names']
        self.num_classes = self.config['nc']

        print(f"πŸ“ Dataset loaded: {self.num_classes} class(es)")
        print(f"   Classes: {self.class_names}")

    def _get_split_paths(self, split):
        """Get image and label paths for a split."""
        # Handle relative paths
        img_path = self.config[split]
        if img_path.startswith('../'):
            # lstrip('../') strips *characters* ('.' and '/'), not the prefix; remove the literal '../' instead
            img_path = os.path.join(self.dataset_root, img_path[3:])

        # Get labels path (replace /images with /labels)
        label_path = img_path.replace('/images', '/labels')

        return img_path, label_path

    def analyze_dataset_splits(self):
        """
        Analyze dataset splits: train, val, test.
        Returns statistics for each split.
        """
        splits_data = {}

        for split in ['train', 'val', 'test']:
            if split not in self.config:
                continue

            img_path, label_path = self._get_split_paths(split)

            if not os.path.exists(img_path):
                print(f"⚠️  Warning: {split} images not found at {img_path}")
                continue

            # Count images
            image_files = [f for f in os.listdir(img_path)
                          if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
            num_images = len(image_files)

            # Count annotations
            total_annotations = 0
            class_counts = defaultdict(int)
            images_with_annotations = 0

            if os.path.exists(label_path):
                for img_file in image_files:
                    label_file = os.path.splitext(img_file)[0] + '.txt'
                    label_file_path = os.path.join(label_path, label_file)

                    if os.path.exists(label_file_path):
                        with open(label_file_path, 'r') as f:
                            lines = f.readlines()
                            if lines:
                                images_with_annotations += 1
                                for line in lines:
                                    if line.strip():
                                        parts = line.strip().split()
                                        if parts:
                                            cls_id = int(parts[0])
                                            class_counts[cls_id] += 1
                                            total_annotations += 1

            splits_data[split] = {
                'num_images': num_images,
                'total_annotations': total_annotations,
                'images_with_annotations': images_with_annotations,
                'class_counts': dict(class_counts),
                'avg_annotations_per_image': total_annotations / num_images if num_images > 0 else 0
            }

            print(f"\n{split.upper()} split:")
            print(f"  Images: {num_images}")
            print(f"  Total annotations: {total_annotations}")
            print(f"  Images with annotations: {images_with_annotations}")
            print(f"  Avg annotations/image: {splits_data[split]['avg_annotations_per_image']:.2f}")

        return splits_data

    def plot_split_statistics(self, splits_data):
        """
        Plot 1: Dataset split bar chart with class annotations.
        """
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

        splits = list(splits_data.keys())

        # Plot 1: Number of images per split
        num_images = [splits_data[s]['num_images'] for s in splits]
        colors = ['#3498db', '#e74c3c', '#2ecc71']

        bars = ax1.bar(splits, num_images, color=colors[:len(splits)],
                      edgecolor='black', alpha=0.7)
        ax1.set_ylabel('Number of Images', fontsize=12)
        ax1.set_title('Dataset Split Distribution', fontsize=14, fontweight='bold')
        ax1.grid(True, alpha=0.3, axis='y')

        # Add value labels on bars
        for bar, count in zip(bars, num_images):
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width()/2., height,
                    f'{int(count)}',
                    ha='center', va='bottom', fontsize=12, fontweight='bold')

        # Plot 2: Annotations per split (stacked by class)
        annotations_by_class = {}
        for cls_id in range(self.num_classes):
            annotations_by_class[cls_id] = [
                splits_data[s]['class_counts'].get(cls_id, 0) for s in splits
            ]

        bottom = np.zeros(len(splits))
        colors_classes = plt.cm.Set3(np.linspace(0, 1, self.num_classes))

        for cls_id in range(self.num_classes):
            # class_names may be a list or an id->name dict; indexing works for both
            class_name = str(self.class_names[cls_id])
            counts = annotations_by_class[cls_id]
            ax2.bar(splits, counts, bottom=bottom, label=class_name,
                   color=colors_classes[cls_id], edgecolor='black', alpha=0.8)

            # Add value labels
            for i, (split_name, count) in enumerate(zip(splits, counts)):
                if count > 0:
                    ax2.text(i, bottom[i] + count/2, str(count),
                            ha='center', va='center', fontsize=10, fontweight='bold')

            bottom += counts

        ax2.set_ylabel('Number of Annotations', fontsize=12)
        ax2.set_title('Annotations per Split (by Class)', fontsize=14, fontweight='bold')
        ax2.legend()
        ax2.grid(True, alpha=0.3, axis='y')

        plt.tight_layout()
        plt.show()

    def analyze_image_dimensions(self):
        """
        Analyze image dimensions across all splits.
        """
        widths = []
        heights = []
        aspect_ratios = []

        for split in ['train', 'val', 'test']:
            if split not in self.config:
                continue

            img_path, _ = self._get_split_paths(split)

            if not os.path.exists(img_path):
                continue

            image_files = [f for f in os.listdir(img_path)
                          if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

            # Sample up to 100 images for speed
            sampled_files = np.random.choice(image_files,
                                           min(100, len(image_files)),
                                           replace=False)

            for img_file in sampled_files:
                img_full_path = os.path.join(img_path, img_file)
                try:
                    with Image.open(img_full_path) as img:
                        w, h = img.size
                        widths.append(w)
                        heights.append(h)
                        aspect_ratios.append(w / h)
                except OSError:
                    # skip unreadable/corrupt image files
                    continue

        return np.array(widths), np.array(heights), np.array(aspect_ratios)

    def plot_image_dimensions(self, widths, heights, aspect_ratios):
        """
        Plot 2: Image dimensions and aspect ratios.
        """
        fig, axes = plt.subplots(1, 3, figsize=(18, 5))

        # Scatter plot: width vs height
        ax = axes[0]
        scatter = ax.scatter(widths, heights, alpha=0.5, s=50, c=aspect_ratios,
                           cmap='viridis', edgecolors='black', linewidth=0.5)
        ax.set_xlabel('Width (pixels)', fontsize=12)
        ax.set_ylabel('Height (pixels)', fontsize=12)
        ax.set_title('Image Dimensions Distribution', fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3)

        # Add average lines
        avg_w, avg_h = widths.mean(), heights.mean()
        ax.axvline(avg_w, color='red', linestyle='--', linewidth=2,
                  label=f'Avg W: {avg_w:.0f}')
        ax.axhline(avg_h, color='blue', linestyle='--', linewidth=2,
                  label=f'Avg H: {avg_h:.0f}')
        ax.legend()

        # Colorbar
        cbar = plt.colorbar(scatter, ax=ax)
        cbar.set_label('Aspect Ratio (W/H)', fontsize=10)

        # Histogram: widths
        ax = axes[1]
        ax.hist(widths, bins=30, color='skyblue', edgecolor='black', alpha=0.7)
        ax.axvline(avg_w, color='red', linestyle='--', linewidth=2,
                  label=f'Mean: {avg_w:.0f}')
        ax.axvline(np.median(widths), color='green', linestyle='--', linewidth=2,
                  label=f'Median: {np.median(widths):.0f}')
        ax.set_xlabel('Width (pixels)', fontsize=12)
        ax.set_ylabel('Frequency', fontsize=12)
        ax.set_title('Width Distribution', fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

        # Histogram: heights
        ax = axes[2]
        ax.hist(heights, bins=30, color='lightcoral', edgecolor='black', alpha=0.7)
        ax.axvline(avg_h, color='blue', linestyle='--', linewidth=2,
                  label=f'Mean: {avg_h:.0f}')
        ax.axvline(np.median(heights), color='green', linestyle='--', linewidth=2,
                  label=f'Median: {np.median(heights):.0f}')
        ax.set_xlabel('Height (pixels)', fontsize=12)
        ax.set_ylabel('Frequency', fontsize=12)
        ax.set_title('Height Distribution', fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

        plt.tight_layout()
        plt.show()

        # Print statistics
        print(f"\nπŸ“ IMAGE DIMENSION STATISTICS:")
        print(f"   Width:  mean={widths.mean():.0f}, median={np.median(widths):.0f}, "
              f"std={widths.std():.0f}")
        print(f"   Height: mean={heights.mean():.0f}, median={np.median(heights):.0f}, "
              f"std={heights.std():.0f}")
        print(f"   Aspect Ratio: mean={aspect_ratios.mean():.2f}, "
              f"median={np.median(aspect_ratios):.2f}")

    def create_annotation_heatmap(self, grid_size=20):
        """
        Plot 3: Heatmap showing where annotations are located in images.
        """
        # Create grid for heatmap (normalized coordinates 0-1)
        heatmap = np.zeros((grid_size, grid_size))

        for split in ['train', 'val', 'test']:
            if split not in self.config:
                continue

            img_path, label_path = self._get_split_paths(split)

            if not os.path.exists(label_path):
                continue

            label_files = [f for f in os.listdir(label_path) if f.endswith('.txt')]

            for label_file in label_files:
                label_file_path = os.path.join(label_path, label_file)

                with open(label_file_path, 'r') as f:
                    for line in f:
                        parts = line.strip().split()
                        if len(parts) >= 5:
                            # YOLO format: class x_center y_center width height (normalized)
                            x_center = float(parts[1])
                            y_center = float(parts[2])

                            # Map to grid
                            grid_x = int(x_center * grid_size)
                            grid_y = int(y_center * grid_size)

                            # Clamp to valid range
                            grid_x = max(0, min(grid_size - 1, grid_x))
                            grid_y = max(0, min(grid_size - 1, grid_y))

                            heatmap[grid_y, grid_x] += 1

        # Plot heatmap
        fig, ax = plt.subplots(figsize=(10, 8))

        im = ax.imshow(heatmap, cmap='hot', interpolation='bilinear', origin='upper')
        ax.set_xlabel('Normalized X Position', fontsize=12)
        ax.set_ylabel('Normalized Y Position', fontsize=12)
        ax.set_title('Annotation Center Heatmap (All Splits)',
                    fontsize=14, fontweight='bold')

        # Add colorbar
        cbar = plt.colorbar(im, ax=ax)
        cbar.set_label('Annotation Density', fontsize=12)

        # Add grid
        ax.set_xticks(np.arange(0, grid_size, grid_size//5))
        ax.set_yticks(np.arange(0, grid_size, grid_size//5))
        ax.set_xticklabels([f'{x/grid_size:.1f}' for x in range(0, grid_size, grid_size//5)])
        ax.set_yticklabels([f'{y/grid_size:.1f}' for y in range(0, grid_size, grid_size//5)])
        ax.grid(True, alpha=0.3, color='white', linewidth=1)

        plt.tight_layout()
        plt.show()

        return heatmap

    def analyze_objects_per_image(self):
        """
        Analyze distribution of object counts per image.
        """
        object_counts = []

        for split in ['train', 'val', 'test']:
            if split not in self.config:
                continue

            img_path, label_path = self._get_split_paths(split)

            if not os.path.exists(img_path) or not os.path.exists(label_path):
                continue

            image_files = [f for f in os.listdir(img_path)
                          if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

            for img_file in image_files:
                label_file = os.path.splitext(img_file)[0] + '.txt'
                label_file_path = os.path.join(label_path, label_file)

                count = 0
                if os.path.exists(label_file_path):
                    with open(label_file_path, 'r') as f:
                        count = len([line for line in f if line.strip()])

                object_counts.append(count)

        return np.array(object_counts)

    def plot_objects_per_image(self, object_counts):
        """
        Plot 4: Histogram of object count per image.
        """
        fig, ax = plt.subplots(figsize=(12, 6))

        max_count = object_counts.max()
        bins = range(0, max_count + 2)

        ax.hist(object_counts, bins=bins, edgecolor='black',
               color='mediumseagreen', alpha=0.7)

        # Statistics
        mean_count = object_counts.mean()
        median_count = np.median(object_counts)

        ax.axvline(mean_count, color='red', linestyle='--', linewidth=2,
                  label=f'Mean: {mean_count:.2f}')
        ax.axvline(median_count, color='blue', linestyle='--', linewidth=2,
                  label=f'Median: {median_count:.0f}')

        ax.set_xlabel('Number of Objects per Image', fontsize=12)
        ax.set_ylabel('Frequency', fontsize=12)
        ax.set_title('Distribution of Object Count per Image',
                    fontsize=14, fontweight='bold')
        ax.legend(fontsize=12)
        ax.grid(True, alpha=0.3, axis='y')

        plt.tight_layout()
        plt.show()

        # Print statistics
        print(f"\nπŸ“¦ OBJECTS PER IMAGE STATISTICS:")
        print(f"   Mean: {mean_count:.2f}")
        print(f"   Median: {median_count:.0f}")
        print(f"   Max: {max_count}")
        print(f"   Images with 0 objects: {np.sum(object_counts == 0)}")
        print(f"   Images with 1 object: {np.sum(object_counts == 1)}")
        print(f"   Images with 2+ objects: {np.sum(object_counts >= 2)}")

    def visualize_sample_images(self, num_samples=6, split='train'):
        """
        Plot 5: Sample images with bounding boxes.
        """
        img_path, label_path = self._get_split_paths(split)

        if not os.path.exists(img_path):
            print(f"⚠️  {split} images not found")
            return

        image_files = [f for f in os.listdir(img_path)
                      if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

        # Sample random images
        sampled_files = np.random.choice(image_files,
                                        min(num_samples, len(image_files)),
                                        replace=False)

        # Create grid
        rows = 2
        cols = 3
        fig, axes = plt.subplots(rows, cols, figsize=(18, 12))
        axes = axes.flatten()

        for idx, img_file in enumerate(sampled_files):
            if idx >= rows * cols:
                break

            # Load image
            img_full_path = os.path.join(img_path, img_file)
            img = cv2.imread(img_full_path)
            if img is None:
                axes[idx].axis('off')
                continue  # skip unreadable/corrupt files (imread returns None)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            h, w = img.shape[:2]

            # Load annotations
            label_file = os.path.splitext(img_file)[0] + '.txt'
            label_file_path = os.path.join(label_path, label_file)

            if os.path.exists(label_file_path):
                with open(label_file_path, 'r') as f:
                    for line in f:
                        parts = line.strip().split()
                        if len(parts) >= 5:
                            cls_id = int(parts[0])
                            x_center = float(parts[1]) * w
                            y_center = float(parts[2]) * h
                            box_w = float(parts[3]) * w
                            box_h = float(parts[4]) * h

                            # Convert to corner coordinates
                            x1 = int(x_center - box_w / 2)
                            y1 = int(y_center - box_h / 2)
                            x2 = int(x_center + box_w / 2)
                            y2 = int(y_center + box_h / 2)

                            # Draw rectangle
                            cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)

                            # Add label
                            # class_names may be a list or an id->name dict; indexing works for both
                            class_name = str(self.class_names[cls_id])
                            cv2.putText(img, class_name, (x1, y1 - 10),
                                      cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)

            # Display
            axes[idx].imshow(img)
            axes[idx].set_title(f'{split}: {img_file}', fontsize=10)
            axes[idx].axis('off')

        # Hide unused subplots
        for idx in range(len(sampled_files), rows * cols):
            axes[idx].axis('off')

        plt.suptitle(f'Sample Images with Annotations ({split.upper()} split)',
                    fontsize=16, fontweight='bold', y=0.995)
        plt.tight_layout()
        plt.show()

    def run_full_eda(self):
        """
        Run complete EDA pipeline with all visualizations.
        """
        print("="*70)
        print("πŸ” YOLO DATASET EXPLORATORY DATA ANALYSIS")
        print("="*70)

        # 1. Dataset splits analysis
        print("\nπŸ“Š Analyzing dataset splits...")
        splits_data = self.analyze_dataset_splits()
        self.plot_split_statistics(splits_data)

        # 2. Image dimensions
        print("\nπŸ“ Analyzing image dimensions...")
        widths, heights, aspect_ratios = self.analyze_image_dimensions()
        self.plot_image_dimensions(widths, heights, aspect_ratios)

        # 3. Annotation heatmap
        print("\nπŸ—ΊοΈ  Creating annotation heatmap...")
        self.create_annotation_heatmap(grid_size=20)

        # 4. Objects per image
        print("\nπŸ“¦ Analyzing objects per image...")
        object_counts = self.analyze_objects_per_image()
        self.plot_objects_per_image(object_counts)

        # 5. Sample images
        print("\nπŸ–ΌοΈ  Visualizing sample images...")
        for split in ['train', 'val']:
            if split in self.config:
                print(f"\n   {split.upper()} samples:")
                self.visualize_sample_images(num_samples=6, split=split)

        print("\n" + "="*70)
        print("βœ… EDA COMPLETE!")
        print("="*70)


# =======================
# USAGE
# =======================

# Initialize EDA, reusing yaml_path and dataset_root defined in the earlier cell

eda = YOLODatasetEDA(yaml_path, dataset_root)

# Run full analysis
eda.run_full_eda()

# Or run individual analyses
# splits_data = eda.analyze_dataset_splits()
# eda.plot_split_statistics(splits_data)
# widths, heights, aspect_ratios = eda.analyze_image_dimensions()
# eda.plot_image_dimensions(widths, heights, aspect_ratios)
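The per-class counts collected by `analyze_dataset_splits` can feed a weighted loss to counter class imbalance, as the project guidelines suggest. A minimal inverse-frequency sketch, assuming a `class_counts` dict shaped like `splits_data['train']['class_counts']` (the counts below are made up for illustration):

```python
import numpy as np

def inverse_frequency_weights(class_counts, num_classes):
    """Weight each class by inverse frequency, normalized so the average weight per annotation is 1."""
    counts = np.array([class_counts.get(c, 0) for c in range(num_classes)], dtype=float)
    counts = np.maximum(counts, 1.0)  # avoid division by zero for empty classes
    # weights[c] = total / (num_classes * counts[c]); rare classes get larger weights
    return counts.sum() / (num_classes * counts)

# Hypothetical counts: a 9:1 imbalance between two classes
print(inverse_frequency_weights({0: 900, 1: 100}, num_classes=2))
```

How such weights enter training depends on the framework; Ultralytics YOLOv8 does not expose per-class loss weights directly, so in practice this is more likely to inform oversampling or targeted augmentation of the minority class.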
In [ ]:
"""
YOLO Dataset Exploratory Data Analysis Tool
Comprehensive visual analysis of YOLO object detection datasets
"""

import os
import yaml
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from pathlib import Path
import cv2
from collections import defaultdict

class YOLODatasetEDA:
    """
    Comprehensive EDA for YOLO datasets with visualization.
    """

    def __init__(self, yaml_path, dataset_root=None):
        """
        Initialize EDA tool.

        Parameters:
        -----------
        yaml_path : str
            Path to data.yaml file
        dataset_root : str, optional
            Root directory of dataset (if yaml paths are relative)
        """
        self.yaml_path = yaml_path

        # Load YAML
        with open(yaml_path, 'r') as f:
            self.config = yaml.safe_load(f)

        # Set dataset root
        if dataset_root is None:
            dataset_root = os.path.dirname(yaml_path)
        self.dataset_root = dataset_root

        # Get class names
        self.class_names = self.config['names']
        self.num_classes = self.config['nc']

        print(f"πŸ“ Dataset loaded: {self.num_classes} class(es)")
        print(f"   Classes: {self.class_names}")

    def _get_split_paths(self, split):
        """Get image and label paths for a split."""
        # Handle relative paths
        img_path = self.config[split]
        if img_path.startswith('../'):
            # lstrip('../') strips *characters* ('.' and '/'), not the prefix; remove the literal '../' instead
            img_path = os.path.join(self.dataset_root, img_path[3:])

        # Get labels path (replace /images with /labels)
        label_path = img_path.replace('/images', '/labels')

        return img_path, label_path

    def analyze_dataset_splits(self):
        """
        Analyze dataset splits: train, val, test.
        Returns statistics for each split.
        """
        splits_data = {}

        for split in ['train', 'val', 'test']:
            if split not in self.config:
                continue

            img_path, label_path = self._get_split_paths(split)

            if not os.path.exists(img_path):
                print(f"⚠️  Warning: {split} images not found at {img_path}")
                continue

            # Count images
            image_files = [f for f in os.listdir(img_path)
                          if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
            num_images = len(image_files)

            # Count annotations
            total_annotations = 0
            class_counts = defaultdict(int)
            images_with_annotations = 0

            if os.path.exists(label_path):
                for img_file in image_files:
                    label_file = os.path.splitext(img_file)[0] + '.txt'
                    label_file_path = os.path.join(label_path, label_file)

                    if os.path.exists(label_file_path):
                        with open(label_file_path, 'r') as f:
                            lines = f.readlines()
                            if lines:
                                images_with_annotations += 1
                                for line in lines:
                                    if line.strip():
                                        parts = line.strip().split()
                                        if parts:
                                            cls_id = int(parts[0])
                                            class_counts[cls_id] += 1
                                            total_annotations += 1

            splits_data[split] = {
                'num_images': num_images,
                'total_annotations': total_annotations,
                'images_with_annotations': images_with_annotations,
                'class_counts': dict(class_counts),
                'avg_annotations_per_image': total_annotations / num_images if num_images > 0 else 0
            }

            print(f"\n{split.upper()} split:")
            print(f"  Images: {num_images}")
            print(f"  Total annotations: {total_annotations}")
            print(f"  Images with annotations: {images_with_annotations}")
            print(f"  Avg annotations/image: {splits_data[split]['avg_annotations_per_image']:.2f}")

        return splits_data

    def plot_split_statistics(self, splits_data):
        """
        Plot 1: Stacked bar chart - images with Santa vs without Santa per split.
        """
        fig, ax = plt.subplots(figsize=(10, 7))

        splits = list(splits_data.keys())
        x = np.arange(len(splits))

        # Get data
        images_with_santa = [splits_data[s]['images_with_annotations'] for s in splits]
        images_without_santa = [
            splits_data[s]['num_images'] - splits_data[s]['images_with_annotations']
            for s in splits
        ]
        total_images = [splits_data[s]['num_images'] for s in splits]

        # Colors
        color_with = '#e74c3c'      # Red for Santa images
        color_without = '#95a5a6'   # Gray for background/no Santa

        # Stacked bars
        bars1 = ax.bar(x, images_with_santa,
                      label='Images with Santa',
                      color=color_with,
                      edgecolor='black',
                      alpha=0.8,
                      linewidth=1.5)

        bars2 = ax.bar(x, images_without_santa,
                      bottom=images_with_santa,
                      label='Images without Santa (background)',
                      color=color_without,
                      edgecolor='black',
                      alpha=0.8,
                      linewidth=1.5)

        # Add value labels on bars
        for i, (with_santa, without_santa, total) in enumerate(
            zip(images_with_santa, images_without_santa, total_images)):

            # Label for "with Santa" section
            if with_santa > 0:
                ax.text(i, with_santa/2,
                       f'{with_santa}\n({100*with_santa/total:.1f}%)',
                       ha='center', va='center',
                       fontsize=11, fontweight='bold',
                       color='white')

            # Label for "without Santa" section
            if without_santa > 0:
                ax.text(i, with_santa + without_santa/2,
                       f'{without_santa}\n({100*without_santa/total:.1f}%)',
                       ha='center', va='center',
                       fontsize=11, fontweight='bold',
                       color='white')

            # Total on top
            ax.text(i, total, f'Total: {total}',
                   ha='center', va='bottom',
                   fontsize=12, fontweight='bold',
                   color='black')

        # Formatting
        ax.set_xlabel('Dataset Split', fontsize=13, fontweight='bold')
        ax.set_ylabel('Number of Images', fontsize=13, fontweight='bold')
        ax.set_title('Dataset Split Distribution: Images with/without Santa',
                    fontsize=15, fontweight='bold', pad=20)
        ax.set_xticks(x)
        ax.set_xticklabels([s.upper() for s in splits], fontsize=12, fontweight='bold')
        ax.legend(fontsize=11, loc='upper right', framealpha=0.9)
        ax.grid(True, alpha=0.3, axis='y')

        # Set y-axis to start at 0
        ax.set_ylim(bottom=0)

        plt.tight_layout()
        plt.show()

    def analyze_image_dimensions(self):
        """
        Analyze image dimensions across all splits.
        """
        widths = []
        heights = []
        aspect_ratios = []

        for split in ['train', 'val', 'test']:
            if split not in self.config:
                continue

            img_path, _ = self._get_split_paths(split)

            if not os.path.exists(img_path):
                continue

            image_files = [f for f in os.listdir(img_path)
                          if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

            if not image_files:
                continue

            # Sample up to 100 images per split for speed
            sampled_files = np.random.choice(image_files,
                                             min(100, len(image_files)),
                                             replace=False)

            for img_file in sampled_files:
                img_full_path = os.path.join(img_path, img_file)
                try:
                    with Image.open(img_full_path) as img:
                        w, h = img.size
                        widths.append(w)
                        heights.append(h)
                        aspect_ratios.append(w / h)
                except Exception:
                    # Skip unreadable or corrupt image files
                    continue

        return np.array(widths), np.array(heights), np.array(aspect_ratios)

    def plot_image_dimensions(self, widths, heights, aspect_ratios):
        """
        Plot 2: Image dimensions and aspect ratios.
        """
        fig, axes = plt.subplots(1, 3, figsize=(18, 5))

        # Scatter plot: width vs height
        ax = axes[0]
        scatter = ax.scatter(widths, heights, alpha=0.5, s=50, c=aspect_ratios,
                           cmap='viridis', edgecolors='black', linewidth=0.5)
        ax.set_xlabel('Width (pixels)', fontsize=12)
        ax.set_ylabel('Height (pixels)', fontsize=12)
        ax.set_title('Image Dimensions Distribution', fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3)

        # Add average lines
        avg_w, avg_h = widths.mean(), heights.mean()
        ax.axvline(avg_w, color='red', linestyle='--', linewidth=2,
                  label=f'Avg W: {avg_w:.0f}')
        ax.axhline(avg_h, color='blue', linestyle='--', linewidth=2,
                  label=f'Avg H: {avg_h:.0f}')
        ax.legend()

        # Colorbar
        cbar = plt.colorbar(scatter, ax=ax)
        cbar.set_label('Aspect Ratio (W/H)', fontsize=10)

        # Histogram: widths
        ax = axes[1]
        ax.hist(widths, bins=30, color='skyblue', edgecolor='black', alpha=0.7)
        ax.axvline(avg_w, color='red', linestyle='--', linewidth=2,
                  label=f'Mean: {avg_w:.0f}')
        ax.axvline(np.median(widths), color='green', linestyle='--', linewidth=2,
                  label=f'Median: {np.median(widths):.0f}')
        ax.set_xlabel('Width (pixels)', fontsize=12)
        ax.set_ylabel('Frequency', fontsize=12)
        ax.set_title('Width Distribution', fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

        # Histogram: heights
        ax = axes[2]
        ax.hist(heights, bins=30, color='lightcoral', edgecolor='black', alpha=0.7)
        ax.axvline(avg_h, color='blue', linestyle='--', linewidth=2,
                  label=f'Mean: {avg_h:.0f}')
        ax.axvline(np.median(heights), color='green', linestyle='--', linewidth=2,
                  label=f'Median: {np.median(heights):.0f}')
        ax.set_xlabel('Height (pixels)', fontsize=12)
        ax.set_ylabel('Frequency', fontsize=12)
        ax.set_title('Height Distribution', fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

        plt.tight_layout()
        plt.show()

        # Print statistics
        print(f"\nπŸ“ IMAGE DIMENSION STATISTICS:")
        print(f"   Width:  mean={widths.mean():.0f}, median={np.median(widths):.0f}, "
              f"std={widths.std():.0f}")
        print(f"   Height: mean={heights.mean():.0f}, median={np.median(heights):.0f}, "
              f"std={heights.std():.0f}")
        print(f"   Aspect Ratio: mean={aspect_ratios.mean():.2f}, "
              f"median={np.median(aspect_ratios):.2f}")

    def create_annotation_heatmap(self, grid_size=20):
        """
        Plot 3: Heatmap showing where annotations are located in images.
        """
        # Create grid for heatmap (normalized coordinates 0-1)
        heatmap = np.zeros((grid_size, grid_size))

        for split in ['train', 'val', 'test']:
            if split not in self.config:
                continue

            img_path, label_path = self._get_split_paths(split)

            if not os.path.exists(label_path):
                continue

            label_files = [f for f in os.listdir(label_path) if f.endswith('.txt')]

            for label_file in label_files:
                label_file_path = os.path.join(label_path, label_file)

                with open(label_file_path, 'r') as f:
                    for line in f:
                        parts = line.strip().split()
                        if len(parts) >= 5:
                            # YOLO format: class x_center y_center width height (normalized)
                            x_center = float(parts[1])
                            y_center = float(parts[2])

                            # Map to grid
                            grid_x = int(x_center * grid_size)
                            grid_y = int(y_center * grid_size)

                            # Clamp to valid range
                            grid_x = max(0, min(grid_size - 1, grid_x))
                            grid_y = max(0, min(grid_size - 1, grid_y))

                            heatmap[grid_y, grid_x] += 1

        # Plot heatmap
        fig, ax = plt.subplots(figsize=(10, 8))

        im = ax.imshow(heatmap, cmap='hot', interpolation='bilinear', origin='upper')
        ax.set_xlabel('Normalized X Position', fontsize=12)
        ax.set_ylabel('Normalized Y Position', fontsize=12)
        ax.set_title('Annotation Center Heatmap (All Splits)',
                    fontsize=14, fontweight='bold')

        # Add colorbar
        cbar = plt.colorbar(im, ax=ax)
        cbar.set_label('Annotation Density', fontsize=12)

        # Add grid
        ax.set_xticks(np.arange(0, grid_size, grid_size//5))
        ax.set_yticks(np.arange(0, grid_size, grid_size//5))
        ax.set_xticklabels([f'{x/grid_size:.1f}' for x in range(0, grid_size, grid_size//5)])
        ax.set_yticklabels([f'{y/grid_size:.1f}' for y in range(0, grid_size, grid_size//5)])
        ax.grid(True, alpha=0.3, color='white', linewidth=1)

        plt.tight_layout()
        plt.show()

        return heatmap

    def analyze_objects_per_image(self):
        """
        Analyze distribution of object counts per image.
        """
        object_counts = []

        for split in ['train', 'val', 'test']:
            if split not in self.config:
                continue

            img_path, label_path = self._get_split_paths(split)

            if not os.path.exists(img_path) or not os.path.exists(label_path):
                continue

            image_files = [f for f in os.listdir(img_path)
                          if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

            for img_file in image_files:
                label_file = os.path.splitext(img_file)[0] + '.txt'
                label_file_path = os.path.join(label_path, label_file)

                count = 0
                if os.path.exists(label_file_path):
                    with open(label_file_path, 'r') as f:
                        count = len([line for line in f if line.strip()])

                object_counts.append(count)

        return np.array(object_counts)

    def plot_objects_per_image(self, object_counts):
        """
        Plot 4: Histogram of object count per image.
        """
        fig, ax = plt.subplots(figsize=(12, 6))

        max_count = object_counts.max()
        bins = range(0, max_count + 2)

        ax.hist(object_counts, bins=bins, edgecolor='black',
               color='mediumseagreen', alpha=0.7)

        # Statistics
        mean_count = object_counts.mean()
        median_count = np.median(object_counts)

        ax.axvline(mean_count, color='red', linestyle='--', linewidth=2,
                  label=f'Mean: {mean_count:.2f}')
        ax.axvline(median_count, color='blue', linestyle='--', linewidth=2,
                  label=f'Median: {median_count:.0f}')

        ax.set_xlabel('Number of Objects per Image', fontsize=12)
        ax.set_ylabel('Frequency', fontsize=12)
        ax.set_title('Distribution of Object Count per Image',
                    fontsize=14, fontweight='bold')
        ax.legend(fontsize=12)
        ax.grid(True, alpha=0.3, axis='y')

        plt.tight_layout()
        plt.show()

        # Print statistics
        print(f"\n📦 OBJECTS PER IMAGE STATISTICS:")
        print(f"   Mean: {mean_count:.2f}")
        print(f"   Median: {median_count:.0f}")
        print(f"   Max: {max_count}")
        print(f"   Images with 0 objects: {np.sum(object_counts == 0)}")
        print(f"   Images with 1 object: {np.sum(object_counts == 1)}")
        print(f"   Images with 2+ objects: {np.sum(object_counts >= 2)}")

    def visualize_sample_images(self, num_samples=6, split='train'):
        """
        Plot 5: Sample images with bounding boxes.
        """
        img_path, label_path = self._get_split_paths(split)

        if not os.path.exists(img_path):
            print(f"⚠️  {split} images not found")
            return

        image_files = [f for f in os.listdir(img_path)
                      if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

        if not image_files:
            print(f"⚠️  No images found in {split} split")
            return

        # Sample random images
        sampled_files = np.random.choice(image_files,
                                         min(num_samples, len(image_files)),
                                         replace=False)

        # Create grid
        rows = 2
        cols = 3
        fig, axes = plt.subplots(rows, cols, figsize=(18, 12))
        axes = axes.flatten()

        for idx, img_file in enumerate(sampled_files):
            if idx >= rows * cols:
                break

            # Load image (skip files OpenCV cannot decode)
            img_full_path = os.path.join(img_path, img_file)
            img = cv2.imread(img_full_path)
            if img is None:
                continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            h, w = img.shape[:2]

            # Load annotations
            label_file = os.path.splitext(img_file)[0] + '.txt'
            label_file_path = os.path.join(label_path, label_file)

            if os.path.exists(label_file_path):
                with open(label_file_path, 'r') as f:
                    for line in f:
                        parts = line.strip().split()
                        if len(parts) >= 5:
                            cls_id = int(parts[0])
                            x_center = float(parts[1]) * w
                            y_center = float(parts[2]) * h
                            box_w = float(parts[3]) * w
                            box_h = float(parts[4]) * h

                            # Convert to corner coordinates
                            x1 = int(x_center - box_w / 2)
                            y1 = int(y_center - box_h / 2)
                            x2 = int(x_center + box_w / 2)
                            y2 = int(y_center + box_h / 2)

                            # Draw rectangle
                            cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)

                            # Add label (integer indexing works for both the
                            # list and {id: name} dict forms of `names`)
                            class_name = self.class_names[cls_id]
                            cv2.putText(img, class_name, (x1, y1 - 10),
                                      cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)

            # Display
            axes[idx].imshow(img)
            axes[idx].set_title(f'{split}: {img_file}', fontsize=10)
            axes[idx].axis('off')

        # Hide unused subplots
        for idx in range(len(sampled_files), rows * cols):
            axes[idx].axis('off')

        plt.suptitle(f'Sample Images with Annotations ({split.upper()} split)',
                    fontsize=16, fontweight='bold', y=0.995)
        plt.tight_layout()
        plt.show()

    def run_full_eda(self):
        """
        Run complete EDA pipeline with all visualizations.
        """
        print("="*70)
        print("πŸ” YOLO DATASET EXPLORATORY DATA ANALYSIS")
        print("="*70)

        # 1. Dataset splits analysis
        print("\n📊 Analyzing dataset splits...")
        splits_data = self.analyze_dataset_splits()
        self.plot_split_statistics(splits_data)

        # 2. Image dimensions
        print("\nπŸ“ Analyzing image dimensions...")
        widths, heights, aspect_ratios = self.analyze_image_dimensions()
        self.plot_image_dimensions(widths, heights, aspect_ratios)

        # 3. Annotation heatmap
        print("\n🗺️  Creating annotation heatmap...")
        self.create_annotation_heatmap(grid_size=20)

        # 4. Objects per image
        print("\n📦 Analyzing objects per image...")
        object_counts = self.analyze_objects_per_image()
        self.plot_objects_per_image(object_counts)

        # 5. Sample images
        print("\n🖼️  Visualizing sample images...")
        for split in ['train', 'val']:
            if split in self.config:
                print(f"\n   {split.upper()} samples:")
                self.visualize_sample_images(num_samples=6, split=split)

        print("\n" + "="*70)
        print("✅ EDA COMPLETE!")
        print("="*70)
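The box drawing in `visualize_sample_images` converts YOLO-normalized boxes (class, x_center, y_center, width, height, all in [0, 1]) into pixel corner coordinates. A minimal standalone sketch of that same conversion, handy for unit-testing the parsing logic; `yolo_to_corners` is a hypothetical helper, not part of the class above:

```python
# Sketch of the YOLO box conversion used in `visualize_sample_images`
# (hypothetical helper, not part of the EDA class).
def yolo_to_corners(x_center, y_center, box_w, box_h, img_w, img_h):
    """Convert normalized YOLO (cx, cy, w, h) to pixel (x1, y1, x2, y2)."""
    cx, cy = x_center * img_w, y_center * img_h  # center in pixels
    bw, bh = box_w * img_w, box_h * img_h        # box size in pixels
    x1 = int(cx - bw / 2)
    y1 = int(cy - bh / 2)
    x2 = int(cx + bw / 2)
    y2 = int(cy + bh / 2)
    return x1, y1, x2, y2

# A centered box covering half the image in each dimension on a 640x480 image:
print(yolo_to_corners(0.5, 0.5, 0.5, 0.5, 640, 480))  # (160, 120, 480, 360)
```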
In [ ]:
# =======================
# USAGE
# =======================

# Initialize EDA; `yaml_path` and `dataset_root` are defined earlier
# in the notebook (update them there if your dataset lives elsewhere)
eda = YOLODatasetEDA(yaml_path, dataset_root)

# Run full analysis
eda.run_full_eda()

# Or run individual analyses
# splits_data = eda.analyze_dataset_splits()
# eda.plot_split_statistics(splits_data)
# widths, heights, aspect_ratios = eda.analyze_image_dimensions()
# eda.plot_image_dimensions(widths, heights, aspect_ratios)
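If the split analysis reveals class imbalance (a guideline concern noted at the top of this section), one common mitigation is a weighted loss with per-class weights inversely proportional to class frequency. A minimal sketch under the assumption that per-class instance counts have been tallied from the label files; the counts below are illustrative placeholders, not dataset values:

```python
import numpy as np

def inverse_freq_weights(counts):
    """Per-class weights proportional to 1/frequency
    (sklearn-style 'balanced' scheme: total / (n_classes * count))."""
    counts = np.asarray(counts, dtype=float)
    return counts.sum() / (len(counts) * counts)

# Illustrative counts, e.g. [santa, grinch-like hard negative] instances;
# in practice these would come from parsing the label files above.
counts = [800, 200]
print(inverse_freq_weights(counts))  # rarer class gets the larger weight
```

The weights could then scale the per-class classification loss so the rarer class contributes proportionally more to each gradient step.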